import os
import random
from torch.utils.data import Dataset


def get_objects(dataset):
    """Return the object category names of a supported benchmark.

    Args:
        dataset: Benchmark name, either ``'mvtec'`` or ``'visa'``.

    Returns:
        A new list of category (sub-directory) names. A fresh list is built
        on every call, so callers may mutate it freely.

    Raises:
        ValueError: if ``dataset`` is not a supported benchmark name.
            (Previously the function silently returned ``None`` here.)
    """
    if dataset == 'mvtec':
        return [
            "bottle",
            "cable",
            "capsule",
            "carpet",
            "grid",
            "hazelnut",
            "leather",
            "metal_nut",
            "pill",
            "screw",
            "tile",
            "toothbrush",
            "transistor",
            "wood",
            "zipper",
        ]
    if dataset == 'visa':
        return [
            "candle",
            "capsules",
            "cashew",
            "chewinggum",
            "fryum",
            "macaroni1",
            "macaroni2",
            "pcb1",
            "pcb2",
            "pcb3",
            "pcb4",
            "pipe_fryum",
        ]
    raise ValueError(
        f"Unsupported dataset {dataset!r}; expected 'mvtec' or 'visa'"
    )
    

# Maps some VisA category identifiers (entries of get_objects('visa')) to
# alternative object names; the numbered macaroni and PCB variants collapse
# onto a single name per family. Categories not listed here have no entry.
object_dictionary = dict(
    chewinggum="chewing_gum",
    macaroni1="macaroni",
    macaroni2="macaroni",
    **{f"pcb{index}": "printed_circuit_board" for index in range(1, 5)},
)


class CustomDataset(Dataset):
    """Index image (and ground-truth mask) file paths for one object category.

    Few-shot/train mode (``shot > 0``): holds ``shot`` image paths sampled at
    random (via the module-level ``random`` state) from
    ``<base_dir>/<object>/train/good``. Each item is a single path string.

    Test mode (``shot <= 0``): holds every file under
    ``<base_dir>/<object>/test/<class>/`` paired with its expected mask path
    under ``<base_dir>/<object>/ground_truth/<class>/``. The mask path is an
    empty string for the ``"good"`` class. Each item is an
    ``(image_path, mask_path)`` tuple. Classes and file names are visited in
    sorted order, and non-directory entries beside the class folders are
    skipped.

    Only paths are collected; no image or mask file is opened.

    Args:
        dataset: Benchmark name. In test mode it selects the mask naming
            scheme (``"mvtec"``: ``<stem>_mask<ext>``; ``"visa"``:
            ``<stem>.png``).
        object: Category sub-directory under ``base_dir`` (the name shadows
            the builtin but is kept for keyword-argument compatibility).
        base_dir: Root directory of the dataset.
        shot: Number of training images to sample; ``0`` selects the test
            split.

    Raises:
        ValueError: if ``shot`` exceeds the number of training images (raised
            by ``random.sample``), or if a mask path is needed for a
            ``dataset`` other than ``"mvtec"``/``"visa"``.
    """

    def __init__(self, dataset, object, base_dir, shot=0):
        base_dir = os.path.join(base_dir, object)
        self.shot = shot
        if self.shot > 0:
            self.data = self._load_train(base_dir, self.shot)
        else:
            self.data = self._load_test(dataset, base_dir)

    @staticmethod
    def _load_train(base_dir, shot):
        """Return ``shot`` randomly sampled image paths from ``train/good``."""
        good_dir = os.path.join(base_dir, "train", "good")
        # Sort before sampling so the result depends only on the RNG state,
        # not on the filesystem-dependent order of os.listdir.
        file_names = random.sample(sorted(os.listdir(good_dir)), shot)
        return [os.path.join(good_dir, file_name) for file_name in file_names]

    @staticmethod
    def _mask_name(dataset, file_name):
        """Return the ground-truth mask file name for an anomalous image."""
        base_name, extension = os.path.splitext(file_name)
        if dataset == "mvtec":
            return f"{base_name}_mask{extension}"
        if dataset == "visa":
            return f"{base_name}.png"
        # Previously this fell through and failed with UnboundLocalError.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'mvtec' or 'visa'"
        )

    def _load_test(self, dataset, base_dir):
        """Return ``(image_path, mask_path)`` pairs for the test split."""
        image_dir = os.path.join(base_dir, "test")
        mask_dir = os.path.join(base_dir, "ground_truth")
        data = []
        for defect in sorted(os.listdir(image_dir)):
            class_dir = os.path.join(image_dir, defect)
            # Skip stray files (e.g. .DS_Store) next to the class folders;
            # listing them as directories would raise NotADirectoryError.
            if not os.path.isdir(class_dir):
                continue
            for file_name in sorted(os.listdir(class_dir)):
                image_path = os.path.join(class_dir, file_name)
                if defect == "good":
                    # Defect-free images have no ground-truth mask.
                    mask_path = ""
                else:
                    mask_path = os.path.join(
                        mask_dir, defect, self._mask_name(dataset, file_name)
                    )
                data.append((image_path, mask_path))
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # train dataset: a single image path
        if self.shot > 0:
            image_path = self.data[idx]
            return image_path
        # test dataset: (image_path, mask_path)
        else:
            image_path, mask_path = self.data[idx]
            return image_path, mask_path
        